import torch
import torch.nn as nn
import torchvision
import time
import numpy as np
import random
from hyper_params import hp

import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image

class myencoder(nn.Module):
    """Encode an (input, target) image pair into a Gaussian latent sample.

    The backbone is a six-stage stride-2 conv encoder mirrored by a
    transposed-conv decoder whose stages are averaged with the matching
    encoder features (see ``get_feature``).  The deepest encoder features of
    the input image are globally average-pooled and projected to a mean and a
    log-variance, from which ``z`` is drawn with the reparameterisation trick.

    Args:
        hps: hyper-parameter object; only ``hps.Nz`` (latent size) is read.
        fin: base channel width.  Encoder stages have ``2*fin`` ... ``64*fin``
            channels and the latent heads consume ``64*fin`` pooled features
            (512 for the default ``fin=8``).
    """

    # Number of extra re-encode/decode refinement passes in ``get_feature``.
    _REFINE_STEPS = 2

    def __init__(self, hps, fin=8):
        super().__init__()
        self.hps = hps
        self.fin = fin
        # Modules are created in the original order so parameter registration
        # order, state_dict keys and seeded initialisation are unchanged.
        self.downsampling1 = self._down_block(1, fin * 2, 7)
        self.downsampling2 = self._down_block(fin * 2, fin * 4, 5)
        self.downsampling3 = self._down_block(fin * 4, fin * 8, 3)
        self.downsampling4 = self._down_block(fin * 8, fin * 16, 3)
        self.downsampling5 = self._down_block(fin * 16, fin * 32, 3)
        self.downsampling6 = self._down_block(fin * 32, fin * 64, 3)

        self.pooling = nn.AdaptiveAvgPool2d(1)

        self.upsampling6 = self._up_block(fin * 64, fin * 32, 3)
        self.upsampling5 = self._up_block(fin * 32, fin * 16, 3)
        self.upsampling4 = self._up_block(fin * 16, fin * 8, 3)
        self.upsampling3 = self._up_block(fin * 8, fin * 4, 3)
        self.upsampling2 = self._up_block(fin * 4, fin * 2, 5)
        # Output layer: no norm/ReLU here; get_feature applies tanh instead.
        self.upsampling1 = nn.Sequential(
            nn.ConvTranspose2d(fin * 2, 1, 7, 2, padding=3, output_padding=1))

        # Latent heads read the pooled deepest features, which have fin * 64
        # channels (this equals 512 for the default fin=8).
        feat_dim = fin * 64
        self.fc_mu = nn.Linear(feat_dim, self.hps.Nz)
        self.fc_sigma = nn.Linear(feat_dim, self.hps.Nz)
        self.criterionL1 = nn.L1Loss()
        self.criterionL2 = nn.MSELoss()

    @staticmethod
    def _down_block(in_ch, out_ch, kernel):
        """Stride-2 conv + InstanceNorm + ReLU; halves (ceil) H and W."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel, stride=2, padding=kernel // 2),
            nn.InstanceNorm2d(out_ch),
            nn.ReLU(True))

    @staticmethod
    def _up_block(in_ch, out_ch, kernel):
        """Stride-2 transposed conv + InstanceNorm + ReLU; doubles H and W."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel, 2,
                               padding=kernel // 2, output_padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.ReLU(True))

    def forward(self, graph):
        """Encode ``graph[:, 0]`` and compare it against ``graph[:, 1]``.

        Args:
            graph: tensor whose channel 0 is the input image and channel 1 the
                target image; presumably shaped (B, >=2, H, W) — TODO confirm
                with callers.  H and W must be multiples of 64 so every skip
                addition in ``get_feature`` sees matching spatial sizes.

        Returns:
            ``(z, mu, sigma, latent_code, feat_loss, rec_loss)`` where
            ``sigma`` is treated as a log-variance (std = exp(sigma / 2)),
            ``latent_code`` is the (B, 64 * fin) pooled feature vector,
            ``feat_loss`` is the L1 distance between the deepest features of
            input and target, and ``rec_loss`` is the MSE between the refined
            reconstruction of the input and the target image.
        """
        # Slicing keeps a singleton channel dim, (B, 1, H, W), and unlike the
        # former view(B, 1, W, W) also works for non-square images.
        gt = graph[:, 1:2]
        image = graph[:, 0:1]

        top, bottom = self.get_feature(image)
        gt_top, _ = self.get_feature(gt)

        latent_code = self.pooling(top).flatten(1)

        mu = self.fc_mu(latent_code)
        sigma = self.fc_sigma(latent_code)
        sigma_e = torch.exp(sigma / 2.)
        # Standard-normal noise created directly on mu's device (default
        # dtype, as before) instead of on CPU followed by a .cuda() copy.
        n = torch.randn(mu.shape, device=mu.device)
        # Reparameterised sample.
        z = mu + sigma_e * n

        return (z, mu, sigma, latent_code,
                self.criterionL1(top, gt_top),
                self.criterionL2(bottom, gt))

    def _decode(self, x1, x2, x3, x4, x5, x6):
        """Decode from ``x6``, averaging each stage with its encoder skip.

        Returns ``(y0, y1, y2, y3, y4, y5)`` where ``y0`` is the tanh output
        at the input resolution and ``y1``..``y5`` are the intermediate
        decoder features matching ``x1``..``x5``.
        """
        y5 = 0.5 * self.upsampling6(x6) + 0.5 * x5
        y4 = 0.5 * self.upsampling5(y5) + 0.5 * x4
        y3 = 0.5 * self.upsampling4(y4) + 0.5 * x3
        y2 = 0.5 * self.upsampling3(y3) + 0.5 * x2
        y1 = 0.5 * self.upsampling2(y2) + 0.5 * x1
        y0 = torch.tanh(self.upsampling1(y1))
        return y0, y1, y2, y3, y4, y5

    def get_feature(self, image):
        """Run the iteratively refined encoder/decoder on ``image``.

        Returns:
            ``(x6, y0)``: the deepest encoder features from the last pass and
            the tanh reconstruction from the last pass.
        """
        x1 = self.downsampling1(image)
        x2 = self.downsampling2(x1)
        x3 = self.downsampling3(x2)
        x4 = self.downsampling4(x3)
        x5 = self.downsampling5(x4)
        x6 = self.downsampling6(x5)
        y0, y1, y2, y3, y4, y5 = self._decode(x1, x2, x3, x4, x5, x6)

        for _ in range(self._REFINE_STEPS):
            # Re-encode the reconstruction, averaging each encoder stage with
            # the matching decoder feature from the previous pass.
            x1 = 0.5 * self.downsampling1(y0) + 0.5 * y1
            x2 = 0.5 * self.downsampling2(x1) + 0.5 * y2
            x3 = 0.5 * self.downsampling3(x2) + 0.5 * y3
            x4 = 0.5 * self.downsampling4(x3) + 0.5 * y4
            x5 = 0.5 * self.downsampling5(x4) + 0.5 * y5
            x6 = self.downsampling6(x5)
            y0, y1, y2, y3, y4, y5 = self._decode(x1, x2, x3, x4, x5, x6)
        return x6, y0
